"""Integration with the `mlx_lm` library."""

from functools import singledispatchmethod
from typing import TYPE_CHECKING, Iterator, List, Optional, Union

from outlines.inputs import Chat
from outlines.models.base import Model, ModelTypeAdapter
from outlines.models.transformers import TransformerTokenizer
from outlines.processors import OutlinesLogitsProcessor

if TYPE_CHECKING:
    import mlx.nn as nn
    from transformers import PreTrainedTokenizer

__all__ = ["MLXLM", "from_mlxlm"]


class MLXLMTypeAdapter(ModelTypeAdapter):
    """Type adapter for the `MLXLM` model."""

    def __init__(self, **kwargs):
        """
        Parameters
        ----------
        tokenizer
            The tokenizer used to apply a chat template to `Chat` inputs.

        """
        self.tokenizer = kwargs.get("tokenizer")

    @singledispatchmethod
    def format_input(self, model_input):
        """Generate the prompt argument to pass to the model.

        Parameters
        ----------
        model_input
            The input provided by the user.

        Returns
        -------
        str
            The formatted input to be passed to the model.

        """
        raise NotImplementedError(
            f"The input type {type(model_input)} is not available with "
            "mlx-lm. The available types are `str` and `Chat`."
        )

    @format_input.register(str)
    def format_str_input(self, model_input: str) -> str:
        return model_input

    @format_input.register(Chat)
    def format_chat_input(self, model_input: Chat) -> str:
        if not all(
            isinstance(message["content"], str)
            for message in model_input.messages
        ):
            raise ValueError(
                "mlx-lm does not support multi-modal messages. "
                "The content of each message must be a string."
            )

        return self.tokenizer.apply_chat_template(
            model_input.messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def format_output_type(
        self, output_type: Optional[OutlinesLogitsProcessor] = None,
    ) -> Optional[List[OutlinesLogitsProcessor]]:
        """Generate the logits processor argument to pass to the model.

        Parameters
        ----------
        output_type
            The logits processor provided.

        Returns
        -------
        Optional[list[OutlinesLogitsProcessor]]
            The logits processor argument to be passed to the model.

        """
        if not output_type:
            return None
        return [output_type]


class MLXLM(Model):
    """Thin wrapper around an `mlx_lm` model.

    This wrapper is used to convert the input and output types specified by
    the user at a higher level into arguments for the `mlx_lm` library.

    """

    tensor_library_name = "mlx"

    def __init__(
        self,
        model: "nn.Module",
        tokenizer: "PreTrainedTokenizer",
    ):
        """
        Parameters
        ----------
        model
            An instance of an `mlx_lm` model.
        tokenizer
            An instance of an `mlx_lm` tokenizer or of a compatible
            `transformers` tokenizer.

        """
        self.model = model
        # self.mlx_tokenizer is used by mlx-lm in its generation functions
        self.mlx_tokenizer = tokenizer
        # self.tokenizer is used by the logits processor
        self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)
        self.type_adapter = MLXLMTypeAdapter(tokenizer=tokenizer)

    def generate(
        self,
        model_input: Union[str, Chat],
        output_type: Optional[OutlinesLogitsProcessor] = None,
        **kwargs,
    ) -> str:
        """Generate text using `mlx-lm`.

        Parameters
        ----------
        model_input
            The prompt or chat input based on which the model will generate
            a response.
        output_type
            The logits processor the model will use to constrain the format of
            the generated text.
        kwargs
            Additional keyword arguments to pass to the `mlx-lm` library.

        Returns
        -------
        str
            The text generated by the model.
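
        Examples
        --------
        A minimal usage sketch, assuming `mlx-lm` is installed; the
        checkpoint name is only an example:

        >>> import mlx_lm
        >>> model = MLXLM(*mlx_lm.load("mlx-community/Llama-3.2-1B-Instruct-4bit"))
        >>> text = model.generate("Write a haiku about the sea.", max_tokens=48)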

        """
        from mlx_lm import generate

        return generate(
            self.model,
            self.mlx_tokenizer,
            self.type_adapter.format_input(model_input),
            logits_processors=self.type_adapter.format_output_type(output_type),
            **kwargs,
        )

    def generate_batch(
        self,
        model_input: list[Union[str, Chat]],
        output_type: Optional[OutlinesLogitsProcessor] = None,
        **kwargs,
    ) -> list[str]:
        """Generate a batch of text using `mlx-lm`.

        Parameters
        ----------
        model_input
            The list of prompts or chat inputs based on which the model will
            generate responses.
        output_type
            The logits processor the model will use to constrain the format of
            the generated text.
        kwargs
            Additional keyword arguments to pass to the `mlx-lm` library.

        Returns
        -------
        list[str]
            The list of text generated by the model.
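
        Examples
        --------
        A minimal sketch; constrained generation (`output_type`) is not
        supported here, and `batch_generate` requires a recent `mlx-lm`.
        The checkpoint name is only an example:

        >>> import mlx_lm
        >>> model = MLXLM(*mlx_lm.load("mlx-community/Llama-3.2-1B-Instruct-4bit"))
        >>> texts = model.generate_batch(["Hi!", "Name a color."], max_tokens=16)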

        """
        from mlx_lm import batch_generate

        if output_type:
            raise NotImplementedError(
                "mlx-lm does not support constrained generation with batching. "
                "You cannot provide an `output_type` with this method."
            )

        model_input = [self.type_adapter.format_input(item) for item in model_input]

        # Unlike the other generation methods, `batch_generate` requires
        # tokenized prompts.
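        # Add special tokens during encoding unless the prompt already starts
        # with the BOS token (e.g. one inserted by a chat template), so that
        # the BOS token is not duplicated.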
        add_special_tokens = [
            (
                self.mlx_tokenizer.bos_token is None
                or not prompt.startswith(self.mlx_tokenizer.bos_token)
            )
            for prompt in model_input
        ]
        tokenized_model_input = [
            self.mlx_tokenizer.encode(prompt, add_special_tokens=add)
            for prompt, add in zip(model_input, add_special_tokens)
        ]

        response = batch_generate(
            self.model,
            self.mlx_tokenizer,
            tokenized_model_input,
            **kwargs,
        )

        return response.texts

    def generate_stream(
        self,
        model_input: Union[str, Chat],
        output_type: Optional[OutlinesLogitsProcessor] = None,
        **kwargs,
    ) -> Iterator[str]:
        """Stream text using `mlx-lm`.

        Parameters
        ----------
        model_input
            The prompt or chat input based on which the model will generate
            a response.
        output_type
            The logits processor the model will use to constrain the format of
            the generated text.
        kwargs
            Additional keyword arguments to pass to the `mlx-lm` library.

        Returns
        -------
        Iterator[str]
            An iterator that yields the text generated by the model.
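
        Examples
        --------
        A minimal sketch that prints chunks as they are produced; the
        checkpoint name is only an example:

        >>> import mlx_lm
        >>> model = MLXLM(*mlx_lm.load("mlx-community/Llama-3.2-1B-Instruct-4bit"))
        >>> for chunk in model.generate_stream("Count to five.", max_tokens=32):
        ...     print(chunk, end="")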

        """
        from mlx_lm import stream_generate

        for gen_response in stream_generate(
            self.model,
            self.mlx_tokenizer,
            self.type_adapter.format_input(model_input),
            logits_processors=self.type_adapter.format_output_type(output_type),
            **kwargs,
        ):
            yield gen_response.text


def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM:
    """Create an Outlines `MLXLM` model instance from an `mlx_lm` model and a
    tokenizer.

    Parameters
    ----------
    model
        An instance of an `mlx_lm` model.
    tokenizer
        An instance of an `mlx_lm` tokenizer or of a compatible
        transformers tokenizer.

    Returns
    -------
    MLXLM
        An Outlines `MLXLM` model instance.
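
    Examples
    --------
    A typical loading pattern; the checkpoint name is only an example:

    >>> import mlx_lm
    >>> import outlines
    >>> mlx_model, mlx_tokenizer = mlx_lm.load(
    ...     "mlx-community/Llama-3.2-1B-Instruct-4bit"
    ... )
    >>> model = outlines.from_mlxlm(mlx_model, mlx_tokenizer)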

    """
    return MLXLM(model, tokenizer)
